
[Kernels][FI] Skip trtllm attention when num_kv_heads=1#30842

Merged
vllm-bot merged 2 commits into vllm-project:main from yeqcharlotte:fi_kvhead1
Dec 17, 2025

Conversation

@yeqcharlotte
Collaborator

@yeqcharlotte yeqcharlotte commented Dec 17, 2025

Purpose

We hit the following error when running a small model on Blackwell:

[WORKER]:  File "/redacted/path/executor/abstract.py", line 116, in initialize_from_config
[WORKER]:    self.collective_rpc("compile_or_warm_up_model")
[WORKER]:  File "/redacted/path/executor/uniproc_executor.py", line 75, in collective_rpc
[WORKER]:    result = run_method(self.driver_worker, method, args, kwargs)
[WORKER]:  File "/redacted/path/serial_utils.py", line 460, in run_method
[WORKER]:    return func(*args, **kwargs)
[WORKER]:  File "/redacted/path/worker/gpu_worker.py", line 444, in compile_or_warm_up_model
[WORKER]:    kernel_warmup(self)
[WORKER]:  File "/redacted/path/model_executor/warmup/kernel_warmup.py", line 68, in kernel_warmup
[WORKER]:    worker.model_runner._dummy_run(
[WORKER]:  File "/redacted/env/utils/_contextlib.py", line 116, in decorate_context
[WORKER]:    return func(*args, **kwargs)
[WORKER]:  File "/redacted/path/worker/gpu_model_runner.py", line 4130, in _dummy_run
[WORKER]:    outputs = self.model(
[WORKER]:  File "/redacted/path/compilation/cuda_graph.py", line 220, in __call__
[WORKER]:    return self.runnable(*args, **kwargs)
[WORKER]:  File "/redacted/env/nn/modules/module.py", line 1767, in _wrapped_call_impl
[WORKER]:    return self._call_impl(*args, **kwargs)
[WORKER]:  File "/redacted/env/nn/modules/module.py", line 1778, in _call_impl
[WORKER]:    return forward_call(*args, **kwargs)
[WORKER]:  File "/redacted/path/model.py", line 664, in forward
[WORKER]:    transformer_output, _ = self.model_core(
[WORKER]:  File "/redacted/env/nn/modules/module.py", line 1767, in _wrapped_call_impl
[WORKER]:    return self._call_impl(*args, **kwargs)
[WORKER]:  File "/redacted/env/nn/modules/module.py", line 1778, in _call_impl
[WORKER]:    return forward_call(*args, **kwargs)
[WORKER]:  File "/redacted/path/model/transformer.py", line 1804, in forward
[WORKER]:    h, cache = self.transformer_layers_forward(
[WORKER]:  File "/redacted/path/model/transformer.py", line 1930, in transformer_layers_forward
[WORKER]:    h, new_cache = layer_fn(
[WORKER]:  File "/redacted/env/nn/modules/module.py", line 1767, in _wrapped_call_impl
[WORKER]:    return self._call_impl(*args, **kwargs)
[WORKER]:  File "/redacted/env/nn/modules/module.py", line 1778, in _call_impl
[WORKER]:    return forward_call(*args, **kwargs)
[WORKER]:  File "/redacted/path/model/transformer.py", line 1103, in forward
[WORKER]:    residual_stream, new_cache = self.pre_feed_forward_processing(
[WORKER]:  File "/redacted/path/model/transformer.py", line 1045, in pre_feed_forward_processing
[WORKER]:    attn_out, new_cache = self.attention(  # Pre-norm applied inside attention
[WORKER]:  File "/redacted/env/nn/modules/module.py", line 1767, in _wrapped_call_impl
[WORKER]:    return self._call_impl(*args, **kwargs)
[WORKER]:  File "/redacted/env/nn/modules/module.py", line 1778, in _call_impl
[WORKER]:    return forward_call(*args, **kwargs)
[WORKER]:  File "/redacted/path/model/transformer.py", line 572, in forward
[WORKER]:    return self._attention_forward(
[WORKER]:  File "/redacted/path/model/transformer.py", line 704, in _attention_forward
[WORKER]:    output = self.attention(
[WORKER]:  File "/redacted/env/nn/modules/module.py", line 1767, in _wrapped_call_impl
[WORKER]:    return self._call_impl(*args, **kwargs)
[WORKER]:  File "/redacted/env/nn/modules/module.py", line 1778, in _call_impl
[WORKER]:    return forward_call(*args, **kwargs)
[WORKER]:  File "/redacted/path/model/custom_op.py", line 493, in forward
[WORKER]:    return method(*args, **kwargs)
[WORKER]:  File "/redacted/path/model/layers/paged_attention.py", line 268, in forward_native
[WORKER]:    output_vllm = self._attention.forward(query, key, value)
[WORKER]:  File "/redacted/path/attention/layer.py", line 367, in forward
[WORKER]:    torch.ops.vllm.unified_attention_with_output(
[WORKER]:  File "/redacted/env/_ops.py", line 1208, in __call__
[WORKER]:    return self._op(*args, **(kwargs or {}))
[WORKER]:  File "/redacted/path/attention/utils/kv_transfer_utils.py", line 39, in wrapper
[WORKER]:    return func(*args, **kwargs)
[WORKER]:  File "/redacted/path/attention/layer.py", line 869, in unified_attention_with_output
[WORKER]:    self.impl.forward(
[WORKER]:  File "/redacted/path/attention/backends/flashinfer.py", line 1295, in forward
[WORKER]:    trtllm_batch_context_with_kv_cache(
[WORKER]:  File "/redacted/env/flashinfer/prefill.py", line 3513, in trtllm_batch_context_with_kv_cache
[WORKER]:    run_func(
[WORKER]:  File "python/tvm_ffi/cython/function.pxi", line 923, in tvm_ffi.core.Function.__call__
[WORKER]:RuntimeError: Error in function 'buildNdTmaDescriptor' at /workspace/include/flashinfer/trtllm/fmha/kernelParams.h:536: Check failed: false

Detailed errors

Error: Failed to initialize the TMA descriptor due to invalid argument
tmaFormat: 9 dim: 4 gmem: 0xf64d740000
Shape: 64 16 1 7721128 7149852580857340533
Stride: 128 2 4096 337
tileShapes: 64 16 1 1 1024
tileStrides: 1 1 1 1 1969767282
swizzleType: 3

We confirmed that the dummy run fails when num_kv_heads=1, which produces degenerate strides (effectively a stride of 0 along the singleton head dimension). In that case we can fall back to FlashInfer's native attention instead of the TRTLLM kernel. Existing test cases skip this configuration anyway.
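To illustrate why a singleton KV-head dimension degenerates the strides, here is a minimal sketch (the NHD page layout and shapes here are hypothetical, not vLLM's actual cache layout) computing row-major strides for a paged KV cache:

```python
def row_major_strides(shape):
    """Element strides for a contiguous row-major tensor of the given shape."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

# Hypothetical NHD page layout: (block_size, num_kv_heads, head_dim).
healthy = row_major_strides((16, 8, 128))     # [1024, 128, 1]
degenerate = row_major_strides((16, 1, 128))  # [128, 128, 1]

# With num_kv_heads=1, two dimensions collapse onto the same stride (128),
# so the TMA descriptor built from these strides is no longer well-formed.
assert healthy[0] != healthy[1]
assert degenerate[0] == degenerate[1]
```

A singleton dimension carries no addressing information, which is why the descriptor builder rejects it rather than the kernel silently misreading memory.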

Test Plan

 pytest -v tests/kernels/attention/test_flashinfer_trtllm_attention.py::test_trtllm_attention_rejects_num_kv_heads_1

Test Result

tests/kernels/attention/test_flashinfer_trtllm_attention.py::test_trtllm_attention_rejects_num_kv_heads_1 PASSED [100%]

========================================= warnings summary ==========================================
<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================== 1 passed, 2 warnings in 18.72s ===================================

Our internal run succeeded with this change. The rest will depend on CI.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a runtime error that occurs when using TRTLLM attention with num_kv_heads=1. The fix involves adding checks to disable this configuration and fall back to FlashInfer's native attention, which is the correct approach. A corresponding test case has been added to ensure this behavior is enforced. The changes are logical and well-implemented. I have one suggestion to improve the clarity and maintainability of the logic in can_use_trtllm_attention.

Comment on lines +308 to +319
    # num_kv_heads=1 is not supported due to TMA descriptor building limitations.
    # When num_kv_heads=1, the KV cache strides become degenerate (stride_heads ==
    # stride_batch), which causes CUDA's cuTensorMapEncodeTiled to fail because
    # TMA descriptors cannot handle degenerate 4D tensors with singleton dimensions.
    # See: https://fburl.com/352mrydz
    if has_trtllm and num_kv_heads == 1:
        logger.warning_once(
            "TRTLLM attention does not support num_kv_heads=1. "
            "This configuration causes TMA descriptor building to fail due to "
            "degenerate tensor strides. Falling back to FlashInfer attention."
        )
    return has_trtllm and (num_qo_heads % num_kv_heads == 0) and (num_kv_heads != 1)
Contributor


Severity: high

The logic to handle num_kv_heads=1 is correct, but its implementation could be simplified for better readability and maintainability. The current structure separates the warning log from the return logic, making it slightly convoluted. By using an early return for the num_kv_heads == 1 case, we can make the function's control flow more direct and easier to follow.

    # num_kv_heads=1 is not supported due to TMA descriptor building limitations.
    # When num_kv_heads=1, the KV cache strides become degenerate (stride_heads ==
    # stride_batch), which causes CUDA's cuTensorMapEncodeTiled to fail because
    # TMA descriptors cannot handle degenerate 4D tensors with singleton dimensions.
    # See: https://fburl.com/352mrydz
    if num_kv_heads == 1:
        if has_trtllm:
            logger.warning_once(
                "TRTLLM attention does not support num_kv_heads=1. "
                "This configuration causes TMA descriptor building to fail due to "
                "degenerate tensor strides. Falling back to FlashInfer attention."
            )
        return False

    return has_trtllm and (num_qo_heads % num_kv_heads == 0)
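The suggested early-return shape can be shown as a self-contained sketch (a hypothetical standalone function for illustration; vLLM's actual `can_use_trtllm_attention` takes more parameters and uses the project's logger):

```python
import warnings


def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int,
                             has_trtllm: bool = True) -> bool:
    """Illustrative sketch: reject head configurations TRTLLM can't handle."""
    if num_kv_heads == 1:
        # MQA produces degenerate KV-cache strides, which breaks TMA
        # descriptor construction; fall back to FlashInfer's native path.
        if has_trtllm:
            warnings.warn(
                "TRTLLM attention does not support num_kv_heads=1; "
                "falling back to FlashInfer attention."
            )
        return False
    return has_trtllm and (num_qo_heads % num_kv_heads == 0)


assert can_use_trtllm_attention(32, 8) is True
assert can_use_trtllm_attention(32, 1) is False   # the case this PR guards
assert can_use_trtllm_attention(30, 8) is False   # qo heads not divisible
```

The early return keeps the warning and the decision in one place, so the final return expression no longer needs the extra `num_kv_heads != 1` clause.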

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Dec 17, 2025
@houseroad
Collaborator

Or, wondering if we can support num_kv_heads=1 in the FlashInfer trtllm kernel. cc: @yzh119

@yeqcharlotte yeqcharlotte enabled auto-merge (squash) December 17, 2025 07:45
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 17, 2025
@yeqcharlotte
Collaborator Author

We ran into the problem on a smaller debug-model run; a normal-sized model probably wouldn't hit these issues. cc: @pavanimajety @mgoin if the change makes sense to you.

@vllm-bot vllm-bot merged commit a100152 into vllm-project:main Dec 17, 2025
49 of 52 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Dec 17, 2025
@nvpohanh
Contributor

@yeqcharlotte Could you file a FlashInfer GitHub issue https://github.com/flashinfer-ai/flashinfer with repro steps so that we can investigate this issue? Our expectation is that the trtllm attention kernel should support num_kv_heads=1 (namely, MQA). We have tested various MQA tests and it worked for us, so we want to fix this.

@pavanimajety
Collaborator

@yeqcharlotte Did you see any log for the failed TMADescriptor? From this line in Flashinfer -
https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fmha/kernelParams.h#L528-L543

@yeqcharlotte
Collaborator Author

@yeqcharlotte Did you see any log for the failed TMADescriptor? From this line in Flashinfer - https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fmha/kernelParams.h#L528-L543

@pavanimajety included it in the summary section:

Error: Failed to initialize the TMA descriptor due to invalid argument
tmaFormat: 9 dim: 4 gmem: 0xf64d740000
Shape: 64 16 1 7721128 7149852580857340533
Stride: 128 2 4096 337
tileShapes: 64 16 1 1 1024
tileStrides: 1 1 1 1 1969767282
swizzleType: 3

NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Dec 17, 2025
@yzh119

yzh119 commented Dec 17, 2025

Looks like a misconfiguration of the TMA descriptor; the stride 1969767282 looks suspicious to me.

@nvpohanh
Contributor

@yeqcharlotte This PR broke GPT-OSS TP8 functionally (see #30919). Is it possible to revert this PR for now and apply the patch locally in your dev branch until the FlashInfer team finds the root cause, so that we can apply a narrower check? The num_kv_heads=1 check is too broad and would affect many models. Among them, GPT-OSS is a very popular model and we don't want users to have a bad experience.

cc @mgoin

Majid-Taheri pushed a commit to Majid-Taheri/vllm that referenced this pull request Dec 23, 2025
…#30842)

Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Signed-off-by: Ubuntu <mjtaheri68@gmail.com>
shyeh25 added a commit to shyeh25/vllm that referenced this pull request Jan 2, 2026
…-project#30842)"

This reverts commit a100152.

This PR causes a functional issue for GPT-OSS-120B TP8 (NotImplementedError: FlashInfer backend currently does not support attention sinks).

Signed-off-by: shyeh25 <206795756+shyeh25@users.noreply.github.com>
shyeh25 added a commit to shyeh25/vllm that referenced this pull request Jan 9, 2026
…-project#30842)"

This reverts commit a100152.

This PR causes a functional issue for GPT-OSS-120B TP8 (NotImplementedError: FlashInfer backend currently does not support attention sinks).

Signed-off-by: shyeh25 <206795756+shyeh25@users.noreply.github.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…#30842)

Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026

Labels

nvidia ready ONLY add when PR is ready to merge/full CI is needed

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

6 participants